## Copyright 2007 Roee Shlomo, All Rights Reserved
# see LICENSE.txt for license information
import os
import socket
import sqlite3
try:
    from functools import reduce
except ImportError:
    pass
from time import time
from random import randrange
from defer import Deferred
from ktable import KTable
from krpc import KRPC
from RawServer import RawServer
from tokens import TokensHandler
from utility import encodeNodes, encodePeer
from knode import KNode
from khash import newIDInRange, newID
from utility import decodePeers
from actions import FindNode, GetValue, StoreValue
from errors import KrpcProtocolError, KrpcGenericError
from const import K, KEINITIAL_DELAY, KE_DELAY, KE_AGE, CHECKPOINT_INTERVAL, BUCKET_STALENESS, MIN_PING_INTERVAL, NULL_ID
from threading import Thread
from BitTornado.iprangeparse import is_valid_ip
from BitTornado.clock import clock

DEBUG = False

class Factory:
    """
    Background functionality of the DHT node.

    Owns the RawServer (network I/O and task scheduling), the KRPC protocol
    handler and the tokens handler.  The sqlite database (``self.store``),
    our own node (``self.node``) and the routing table (``self.table``) are
    created later by ``_init``, which ``start`` schedules.
    """
    def __init__(self, host, port, dbDir, ipv6_enable = False, upnp = 0, natpmp = False):
        self.host = host
        self.port = port
        self.dbDir = dbDir
        self.store = None  # sqlite3 connection; opened by _loadDB()

        self.rawserver = RawServer(self, host, port, ipv6_enable, upnp, natpmp)
        self.tokensHandler = TokensHandler(self)
        self.krpc = KRPC(self)

        self.contacts = []  # [ip,] - contacts added by bittorrent
        self.announce = {}  # {ip:[hash,],} - info_hashes each ip announced to us

    def Node(self):
        """Create a new (uninitialized) node bound to this factory."""
        return KNode(self)

    def start(self):
        """Schedule initialization and the periodic jobs, then start the RawServer."""
        self.rawserver.add_task(self._init)
        self.rawserver.add_task(self._checkpoint, 60)
        self.rawserver.start()
        self.rawserver.add_task(self._cleanDataBase, KEINITIAL_DELAY)
        self.rawserver.add_task(self.refreshTable, 5, [True])

    def _init(self):
        """Open the database, then load our own node and the routing table."""
        self._loadDB()
        self._loadSelfNode()
        self._loadRoutingTable()

    def _close(self):
        """Save state to the database and shut the RawServer down."""
        self._updateDB()
        self.rawserver.shutdown()

    ####################
    # Database Handler
    ####################
    def _loadDB(self):
        """Open the sqlite database ``dht.db`` in dbDir, creating tables if needed."""
        if DEBUG:
            print("Debug: DHT - _loadDB")
        self.store = sqlite3.connect(os.path.join(self.dbDir.encode("UTF-8"), "dht.db"))
        # Return TEXT columns as plain str instead of unicode
        self.store.text_factory = str
        self._createTables()

    def _createTables(self):
        """Create the kv, nodes and self tables; tables that already exist are kept."""
        statements = ["create table if not exists kv (key binary, value binary, age timestamp, primary key (key, value))",
                      "create table if not exists nodes (id binary primary key, host text, port number)",
                      "create table if not exists self (num number primary key, id binary, age timestamp)"]
        c = self.store.cursor()
        try:
            # One statement per table, so a single pre-existing table no
            # longer prevents the others from being created.
            for statement in statements:
                c.execute(statement)
            self.store.commit()
        finally:
            c.close()

    def _closeDB(self):
        """Close the database connection."""
        self.store.close()

    def _loadSelfNode(self):
        """
        Load our own node ID from the database and set ``self.node``.

        A new random ID is generated when none is stored or the stored one is
        more than 5 days old.  In that case the stored values and nodes are
        deleted as well.
        """
        if DEBUG:
            print("Debug: DHT - loadSelfNode")

        c = self.store.cursor()
        try:
            c.execute('select id, age from self where num = 0')
            data = c.fetchone()

            if not data or time() - data[1] > 86400*5:  # more than 5 days old
                nodeId = newID()
                c.execute('delete from self')
                c.execute("insert into self values (0, ?, ?)", (sqlite3.Binary(nodeId), time()))
                c.execute('delete from kv')
                c.execute('delete from nodes')
                self.store.commit()
            else:
                nodeId = str(data[0])
        finally:
            c.close()

        self.node = self.Node().init(nodeId, self.host, self.port)

    def _saveSelfNode(self):
        """Replace the stored self row with our current node ID and a fresh timestamp."""
        if DEBUG:
            print("Debug: DHT - saveSelfNode")
        c = self.store.cursor()
        try:
            c.execute('delete from self')
            c.execute("insert into self values (0, ?, ?)", (sqlite3.Binary(self.node.id), time()))
            self.store.commit()
        finally:
            c.close()

    def _loadRoutingTable(self):
        """Create ``self.table`` and fill it with the nodes stored in the database."""
        self.table = KTable(self.node)

        c = self.store.cursor()
        try:
            c.execute("select id, host, port from nodes")
            for row in c:
                n = self.Node().init(str(row[0]), row[1], row[2])
                self.table.insertNode(n, contacted = False)
        finally:
            c.close()

        if DEBUG:
            print("Debug: DHT - nodes loaded:", self.stats())

    def _saveRoutingTable(self):
        """Replace the stored nodes with the nodes currently in the routing table."""
        if DEBUG:
            print("Debug: DHT - saveRoutingTable")
        rows = [(sqlite3.Binary(node.id), node.host, node.port)
                for bucket in self.table.buckets for node in bucket.l]
        c = self.store.cursor()
        try:
            c.execute("delete from nodes")
            c.executemany("insert into nodes values (?, ?, ?)", rows)
            self.store.commit()
        finally:
            c.close()

    def _updateDB(self):
        """Save our node and the routing table to the database."""
        if DEBUG:
            print("Debug: DHT - updateDB")
        self._saveSelfNode()
        self._saveRoutingTable()
        if DEBUG:
            print("Debug: DHT - updateDB completed")

    def _flushExpired(self, maxAge=None):
        """
        Delete stored values that were last stored more than maxAge seconds ago.

        maxAge defaults to KE_AGE.  The ``age`` column holds the absolute
        time() of the last store, so entries older than ``now - maxAge`` are
        the expired ones.
        """
        if maxAge is None:
            maxAge = KE_AGE
        c = self.store.cursor()
        try:
            c.execute("delete from kv where age < ?", (time() - maxAge,))
            self.store.commit()
        finally:
            c.close()

    def _retrieveValue(self, key):
        """Return up to 20 values stored locally for key (as str)."""
        c = self.store.cursor()
        try:
            c.execute("select value from kv where key = ? limit ?", (sqlite3.Binary(key), 20))
            return [str(row[0]) for row in c.fetchall()]
        finally:
            c.close()

    def _storeValue(self, key, value):
        """Store the <key:value> pair, or refresh its age if it is already stored."""
        c = self.store.cursor()
        try:
            try:
                c.execute("insert into kv values (?, ?, ?);", (sqlite3.Binary(key), sqlite3.Binary(value), time()))
            except sqlite3.IntegrityError:
                # (key, value) is the primary key: the pair exists, refresh it
                c.execute("update kv set age = ? where key = ? and value = ?", (time(), sqlite3.Binary(key), sqlite3.Binary(value)))
            self.store.commit()
        finally:
            c.close()

    ####################
    # Automatic updates
    ####################
    def _checkpoint(self):
        """Periodically save state, look for close nodes and refresh the table."""
        if DEBUG:
            print("Debug: DHT - checkpoint")
        self._updateDB()
        self.findCloseNodes()
        self.refreshTable()
        # Reschedule with +/-10% jitter around CHECKPOINT_INTERVAL
        self.rawserver.add_task(self._checkpoint, randrange(int(CHECKPOINT_INTERVAL * .9), int(CHECKPOINT_INTERVAL * 1.1)))

    def _cleanDataBase(self):
        """Expire old values, then reschedule itself after KE_DELAY."""
        self._flushExpired()
        self.rawserver.add_task(self._cleanDataBase, KE_DELAY)

    ####################
    # Interface
    ####################
    def addContact(self, host, port):
        """
        Ping this node and add the contact info to the table on pong.

        Returns False if host was already added.  Hostnames that are not
        literal IPs are resolved in a separate thread (see addRouterContact).
        """
        if not isinstance(port, int):
            port = int(port)
        if not isinstance(host, str):
            host = str(host)
        if host in self.contacts:
            return False
        self.contacts.append(host)

        if is_valid_ip(host):
            n = self.Node().init(NULL_ID, host, port)
            self.sendPing(n)
        else:
            # NOTE(review): sendPing then runs on this worker thread, not the
            # RawServer's; confirm that RawServer/KRPC tolerate this.
            Thread(target = self.addRouterContact, args = [host, port]).start()

    def addRouterContact(self, host, port):
        """Resolve host and ping it; return False if resolution fails, else True."""
        try:
            host = socket.gethostbyname(host)
        except socket.error:
            return False
        n = self.Node().init(NULL_ID, host, port)
        self.sendPing(n)
        return True

    def insertNode(self, n, contacted = True):
        """
        Insert a node in our local table, pinging the oldest contact in the bucket if necessary.

        If all you have is a host/port, use addContact, which pings the node
        and inserts it after receiving the PONG.  The reason for the
        separation is that a node cannot be inserted into the table without
        its node ID, so the node passed here must be a properly formed Node
        object with a valid ID.
        """
        old = self.table.insertNode(n, contacted = contacted)
        if old and (clock() - old.lastSeen) > MIN_PING_INTERVAL and old.id != self.node.id:
            # The bucket is full: ping the old node and replace it if it is gone.

            def _staleNodeHandler(err=None, oldnode=old, newnode=n):
                """Called if the pinged node never responds (err is the failure)."""
                self.table.replaceStaleNode(oldnode, newnode)

            def _notStaleNodeHandler(response, oldnode=old):
                """Called when we get a pong from the old node."""
                if response['rsp']['id'] == oldnode.id:
                    self.table.justSeenNode(oldnode.id)

            try:
                df = old.ping(self.node.id)
            except KrpcGenericError:
                _staleNodeHandler()
            else:
                df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)

    def findCloseNodes(self, callback=lambda a: None):
        """
        Run findNode on the ID one away from our own.

        This populates our table with the nodes on the network closest to us.
        It is called periodically from _checkpoint.
        """
        if DEBUG:
            print("Debug: DHT - findCloseNodes")
        target = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
        self.findNode(target, callback)

    def refreshTable(self, force = False, callback=lambda a: None):
        """
        Refresh the table by searching a random ID in each stale bucket.

        force=True refreshes every bucket regardless of its last access time.
        """
        if DEBUG:
            print("Debug: DHT - refreshTable")
        for bucket in self.table.buckets:
            if force or (clock() - bucket.lastAccessed >= BUCKET_STALENESS):
                target = newIDInRange(bucket.min, bucket.max)
                self.findNode(target, callback)

    def stats(self):
        """Return the number of contacts in our routing table."""
        return sum(len(bucket.l) for bucket in self.table.buckets)

    ####################
    # RPC Handler
    ####################
    def krpc_ping(self, id, _krpc_sender, **kwargs):
        """Incoming RPC: got ping.  Insert the sender and reply with our ID."""
        if len(id) != 20:
            raise KrpcProtocolError("invalid id length: %d" % len(id))
        n = self.Node().init(id, *_krpc_sender)
        self.insertNode(n, contacted = False)
        return {"id" : self.node.id}

    def krpc_find_node(self, target, id, _krpc_sender, **kwargs):
        """Incoming RPC: got find_node.  Reply with the closest known nodes to target."""
        if len(id) != 20:
            raise KrpcProtocolError("invalid id length: %d" % len(id))
        if len(target) != 20:
            raise KrpcProtocolError("invalid target id length: %d" % len(target))

        nodes = [node.senderDict() for node in self.table.findNodes(target)]
        n = self.Node().init(id, *_krpc_sender)
        self.insertNode(n, contacted = False)
        return {"nodes" : encodeNodes(nodes), "id" : self.node.id}

    def krpc_get_peers(self, id, info_hash, _krpc_sender, **kwargs):
        """
        Incoming RPC: got get_peers.

        Reply with the locally stored peers for info_hash if there are any,
        otherwise with the closest known nodes.  Both replies carry a token.
        """
        if len(id) != 20:
            raise KrpcProtocolError("invalid id length: %d" % len(id))
        if len(info_hash) != 20:
            raise KrpcProtocolError("invalid info_hash length: %d" % len(info_hash))
        if id == NULL_ID:
            raise KrpcProtocolError("invalid id (NULL ID)")

        n = self.Node().init(id, *_krpc_sender)
        self.insertNode(n, contacted = False)

        values = self._retrieveValue(info_hash)
        token = self.tokensHandler.tokenToSend(info_hash, *_krpc_sender)
        if values:
            return {'values' : values, "id": self.node.id, "token" : token}
        nodes = [node.senderDict() for node in self.table.findNodes(info_hash)]
        return {'nodes' : encodeNodes(nodes), "id": self.node.id, "token" : token}

    def krpc_announce_peer(self, id, info_hash, port, token, _krpc_sender, **kwargs):
        """
        Incoming RPC: got announce_peer.  Store the sender as a peer for info_hash.

        Each IP may announce at most 3 distinct info_hashes.  Further new
        hashes are rejected with KrpcGenericError and are not recorded.
        """
        if len(id) != 20:
            raise KrpcProtocolError("invalid id length: %d" % len(id))
        if len(info_hash) != 20:
            raise KrpcProtocolError("invalid info_hash length: %d" % len(info_hash))
        if not isinstance(port, int):
            try:
                port = int(port)
            except (TypeError, ValueError):
                raise KrpcProtocolError("invalid port")
        if not self.tokensHandler.checkToken(token, info_hash, *_krpc_sender):
            raise KrpcProtocolError("Got invalid token")

        ip = _krpc_sender[0]
        announced = self.announce.setdefault(ip, [])
        if info_hash not in announced:
            # Check before recording: a rejected hash must not be stored, or
            # the peer could never re-announce its accepted hashes.
            if len(announced) >= 3:
                raise KrpcGenericError("I only allow 3 infohash announces per peer!")
            announced.append(info_hash)

        self._storeValue(info_hash, encodePeer((ip, port)))
        n = self.Node().init(id, *_krpc_sender)
        self.insertNode(n, contacted = False)
        return {"id" : self.node.id}

    def sendPing(self, node, callback=None):
        """
        Ping a node.

        On a valid pong the responder is inserted into the table.  On failure
        the node is reported to the table via nodeFailed.  callback, if given,
        is called with no arguments in both cases.
        """
        def _pongHandler(response, node=node, table=self.table, callback=callback):
            _krpc_sender = response['_krpc_sender']
            nodeId = response['rsp']['id']
            if len(nodeId) == 20:
                n = self.Node().init(nodeId, *_krpc_sender)
                table.insertNode(n)
                if callback:
                    callback()

        def _defaultPong(err=None, node=node, table=self.table, callback=callback):
            # err defaults to None so the synchronous failure path below can call it
            table.nodeFailed(node)
            if callback:
                callback()

        try:
            df = node.ping(self.node.id)
        except KrpcGenericError:
            _defaultPong()
        else:
            df.addCallbacks(_pongHandler, _defaultPong)

    def findNode(self, id, callback, errback = None):
        """
        Search the global table for the K closest nodes to id.

        The search starts from the closest nodes in the local table.  Its
        result is delivered through a Deferred to callback (and errback).
        """
        nodes = self.table.findNodes(id)
        d = Deferred()
        if errback:
            d.addCallbacks(callback, errback)
        else:
            d.addCallback(callback)

        state = FindNode(self, id, d.callback)
        self.rawserver.add_task(state.goWithNodes, 0, [nodes])

    def getPeers(self, key, callback, searchlocal = True, donecallback = None):
        """
        Get the peers for key from the global table.

        If the database or routing table is not loaded yet, retry in 3
        seconds.  With searchlocal, locally stored peers are first passed to
        callback.
        """
        # store starts as None in __init__ and table only exists after
        # _init, so test both (hasattr(self, "store") was always true).
        if self.store is None or getattr(self, "table", None) is None:
            self.rawserver.add_task(self.getPeers, 3, [key, callback, searchlocal, donecallback])
            return

        if searchlocal:
            l = self._retrieveValue(key)
            if len(l) > 0:
                l = decodePeers(l)
                self.rawserver.add_task(callback, 0, [l])
        else:
            l = []

        nodes = self.table.findNodes(key)
        state = GetValue(self, key, callback, "getPeers", donecallback)
        self.rawserver.add_task(state.goWithNodes, 0, [nodes, l])

    def announcePeer(self, key, value, callback=None):
        """
        Store value for key on the closest nodes in the global table.

        This call is asynchronous: it finds the closest nodes, then sends
        the store requests to them.
        """
        def _storeValueForKey(nodes, key=key, value=value, response=callback):
            if not response:
                # default callback: ignore the result
                def _storedValueHandler(sender):
                    pass
                response = _storedValueHandler
            action = StoreValue(self, key, value, response, "announcePeer")
            self.rawserver.add_task(action.goWithNodes, 0, [nodes])

        self.findNode(key, _storeValueForKey)

    def getPeersAndAnnounce(self, key, value, callback, searchlocal = True):
        """Get the peers for key, then announce value to the nodes found by the search."""
        def doneCallback(nodes, key = key, value = value):
            action = StoreValue(self, key, value, None, "announcePeer")
            self.rawserver.add_task(action.goWithNodes, 0, [nodes])
        self.getPeers(key, callback, searchlocal, doneCallback)